import os
import numpy as np
import matplotlib.pyplot as plt

import sys

from numpy.lib.function_base import disp
sys.path.insert(0, '/media/tan/_dde_data/projects/ginkgo_vfife/scripts')
from visualization import ResultFileSystem, TimeHistory


def trajectory(t, Im, moment, c=0):
    """ moving trajectory under a constant load with optional linear damping

    Closed-form solution of  a = moment / Im - c * v  with zero initial
    displacement and velocity (e.g. horizontal projectile motion when c = 0).

    Args:
        t (array-like or float): time series, s
        Im (float): inertia mass, kg (or rotational inertia for a rotation DOF)
        moment (float): constant applied load (force or moment)
        c (float): damping coefficient per unit inertia; |c| <= 1e-6 is
            treated as undamped

    Returns:
        (dx, vx, ax) - displace, velocity, accelerate, each shaped like ``t``
    """
    # accept lists/tuples as well as ndarrays (list ** 2 would raise TypeError)
    t = np.asarray(t)
    acc = moment / Im
    if abs(c) > 1e-6:
        decay = np.exp(-c * t)
        # expm1 evaluates exp(-c*t) - 1 without cancellation when c*t is small
        x = acc / c * t + acc / (c ** 2) * np.expm1(-c * t)
        vx = acc / c * (1 - decay)
        ax = acc * decay
    else:
        # undamped limit: uniform acceleration
        ax = acc * np.ones_like(t)
        vx = ax * t
        x = ax * t ** 2 / 2.0
    return (x, vx, ax)


def compare(data, keys, item, figname):
    """ compare results of theory and openVFIFE for one degree of freedom

    Plots every theory curve (solid black) against its openVFIFE counterpart
    (dashed red), labels each pair with its key, and saves the figure.

    Args:
        data (dict): {"theory": {key: ndarray}, "vfife": {key: ndarray}};
            each ndarray is 2-D with columns [t, Ux, Uy, Uz, Rotx, Roty, Rotz]
        keys (list): keys of data["theory"] / data["vfife"] to plot,
            e.g. ["0.0", "0.5", ...]
        item (str): one of "Ux", "Uy", "Uz", "Rotx", "Roty", "Rotz"
        figname (str): output path of the saved figure

    Raises:
        KeyError: if ``item`` is unknown or a key is missing from ``data``
    """
    index = {"Ux": 1, "Uy": 2, "Uz": 3, "Rotx": 4, "Roty": 5, "Rotz": 6}
    ind = index[item]
    theory = data["theory"]
    vfife = data["vfife"]
    figsize = np.array([8, 6]) / 2.54  # 8 cm x 6 cm converted to inches
    fig, ax = plt.subplots(figsize=figsize, tight_layout=True)
    try:
        handles = None
        for cnt, key in enumerate(keys):
            th, vf = theory[key], vfife[key]
            l1, = ax.plot(th[:, 0], th[:, ind], c="black", lw=1.5)
            l2, = ax.plot(vf[:, 0], vf[:, ind], c="red", lw=1.5,
                          ls="dashed")
            handles = (l1, l2)

            if len(vf) > 0:
                # stagger labels 10 samples apart so they do not overlap;
                # clamp the offset so short series do not raise IndexError
                row = max(-len(vf), -10 - cnt * 10)
                ax.text(vf[row, 0], vf[row, ind], key, fontsize=8,
                        ha="center", va="center", c="blue")
        # with no keys there is nothing to put in a legend
        if handles is not None:
            ax.legend(handles, ["theory", "openVFIFE"], loc="best")
        fig.savefig(figname, dpi=300)
    finally:
        # always release the figure, even if plotting or saving failed
        plt.close(fig)


def _theory_histories(t, mass, c):
    """Build the closed-form reference tables for one damping coefficient.

    The three translational DOFs use inertias ``mass``, ``2 * mass`` and
    ``3 * mass`` under a load of 10; the three rotational DOFs all use unit
    inertia under a load of 0.8333333, so they share one solution.

    Args:
        t (1d-ndarray): time series, s
        mass (float): inertia of the first translational DOF
        c (float): damping coefficient passed to ``trajectory``

    Returns:
        (displace, velocity, accelerate) - 2-D arrays of shape (len(t), 7)
        with columns [t, Ux, Uy, Uz, Rotx, Roty, Rotz], the layout that
        ``compare`` indexes.
    """
    translation = [trajectory(t, k * mass, 10, c) for k in (1, 2, 3)]
    # NOTE(review): 0.8333333 looks like a hand-rounded ratio from the
    # openVFIFE test model (perhaps 10/12) -- confirm against the model input.
    rotation = trajectory(t, 1, 0.8333333, c)
    motions = translation + [rotation] * 3
    # trajectory returns (displacement, velocity, acceleration); build one
    # table per quantity, prefixed by the time column.
    return tuple(
        np.column_stack([t] + [motion[q] for motion in motions])
        for q in range(3)
    )


if __name__ == "__main__":
    mass = 10
    clist = [0.5 * i for i in range(3)]
    dt = 0.1
    t_end = 10.0
    # linspace pins the sample count (0 and t_end both included); np.arange
    # with a float step can gain or lose the last point through rounding.
    t = np.linspace(0.0, t_end, int(round(t_end / dt)) + 1)
    keys = ["%.1f" % c for c in clist]
    print(clist, keys)

    # data containers: {"vfife": {key: table}, "theory": {key: table}}
    displace = {"vfife": {}, "theory": {}}
    velocity = {"vfife": {}, "theory": {}}
    accelerate = {"vfife": {}, "theory": {}}

    # load data from openVFIFE results, one case directory "c<key>" per
    # damping value
    result_dir = "/media/tan/_dde_data/projects/ginkgo_vfife/tests/particle"
    for key in keys:
        fname = os.path.join(result_dir, "c" + key)
        print(fname)
        th = TimeHistory(ResultFileSystem(fname))
        # NOTE(review): series are used exactly as returned; no t = 0 row is
        # dropped here.
        displace["vfife"][key] = th.extract_particle_motion(1, "displace")
        velocity["vfife"][key] = th.extract_particle_motion(1, "velocity")
        accelerate["vfife"][key] = th.extract_particle_motion(1, "accelerate")

    # theory results
    for key, c in zip(keys, clist):
        (displace["theory"][key],
         velocity["theory"][key],
         accelerate["theory"][key]) = _theory_histories(t, mass, c)

    cwd = os.getcwd()
    items = ["Ux", "Uy", "Uz", "Rotx", "Roty", "Rotz"]
    quantities = (("displace", displace),
                  ("velocity", velocity),
                  ("accelerate", accelerate))
    for item in items:
        for prefix, data in quantities:
            figname = os.path.join(cwd, prefix + "_" + item + ".svg")
            compare(data, keys, item, figname)



